Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for [[clad::non_differentiable]] in reverse mode #916

Merged

Conversation

MihailMihov
Copy link
Collaborator

fixes #717

I've added the code for handing [[clad::non_differentiable]] that is already present in the forward mode visitor to the one for reverse mode. I also modified the tests from forward mode, but ReverseMode/NonDifferentiable.C is currently failing because of an issue with differentiating operator overloads in reverse mode.

Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

@parth-07
Copy link
Collaborator

parth-07 commented Jun 1, 2024

but ReverseMode/NonDifferentiable.C is currently failing because of an issue with differentiating operator overloads in reverse mode.

If we mark that operator overload as non-differentiable, then should the issue still happen?

@MihailMihov
Copy link
Collaborator Author

If we mark that operator overload as non-differentiable, then should the issue still happen?

I just tried changing the test and it still errors, but this time it is too many arguments to function call, expected 1, have 2, while if it tries to differentiate the operator it is too few arguments to function call, expected 4, have 3. I did open an issue for this and it is #917, where I put a more minimal failing test. I can try and look into the case where it fails when non-differentiable, but if I had to guess they are the same issue and fixing one would fix both cases.

@parth-07
Copy link
Collaborator

parth-07 commented Jun 2, 2024

but if I had to guess they are the same issue and fixing one would fix both cases.

Oh, yes, you are right. Thank you for the details.

@vgvassilev vgvassilev force-pushed the reverse-mode-non-differentiable branch from e3fe735 to 22e5c11 Compare June 2, 2024 16:34
Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clang-tidy made some suggestions

.get();
// Creating a zero derivative
auto* zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: argument comment missing for literal argument 'val' [bugprone-argument-comment]

Suggested change
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, /*val=*/0);

@@ -2867,6 +2913,10 @@
"CXXMethodDecl nodes not supported yet!");
MemberExpr* clonedME = utils::BuildMemberExpr(
m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName());
auto zero =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto zero' can be declared as 'auto *zero' [llvm-qualified-auto]

Suggested change
auto zero =
auto *zero =

@@ -2867,6 +2913,10 @@
"CXXMethodDecl nodes not supported yet!");
MemberExpr* clonedME = utils::BuildMemberExpr(
m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName());
auto zero =
ConstantFolder::synthesizeLiteral(m_Context.DoubleTy, m_Context, 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: argument comment missing for literal argument 'val' [bugprone-argument-comment]

Suggested change
ConstantFolder::synthesizeLiteral(m_Context.DoubleTy, m_Context, 0);
ConstantFolder::synthesizeLiteral(m_Context.DoubleTy, m_Context, /*val=*/0);

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clang-tidy made some suggestions

if (condVarResult.getDecl_dx())
addToCurrentBlock(BuildDeclStmt(condVarResult.getDecl_dx()));
auto condInit = condVarClone->getInit();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto condInit' can be declared as 'auto *condInit' [llvm-qualified-auto]

Suggested change
auto condInit = condVarClone->getInit();
auto *condInit = condVarClone->getInit();

@vgvassilev
Copy link
Owner

@MihailMihov, can you rebase this pull request?

@MihailMihov MihailMihov force-pushed the reverse-mode-non-differentiable branch 3 times, most recently from 530a10d to 34c325b Compare July 6, 2024 15:45
Copy link
Owner

@vgvassilev vgvassilev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you take care of the clang-tidy suggestions and adapt the new test in such a way that we work around the pre-existing issue of #917?

@MihailMihov MihailMihov force-pushed the reverse-mode-non-differentiable branch 3 times, most recently from aae40ae to 105982c Compare July 6, 2024 18:25
Copy link
Contributor

github-actions bot commented Jul 6, 2024

clang-tidy review says "All clean, LGTM! 👍"

@MihailMihov MihailMihov force-pushed the reverse-mode-non-differentiable branch from 105982c to 7b63512 Compare July 6, 2024 20:46
Copy link
Contributor

github-actions bot commented Jul 6, 2024

clang-tidy review says "All clean, LGTM! 👍"

Copy link

codecov bot commented Jul 6, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 93.94%. Comparing base (f4dcf5c) to head (b590586).

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #916      +/-   ##
==========================================
+ Coverage   93.92%   93.94%   +0.01%     
==========================================
  Files          55       55              
  Lines        8038     8061      +23     
==========================================
+ Hits         7550     7573      +23     
  Misses        488      488              
Files Coverage Δ
lib/Differentiator/ReverseModeVisitor.cpp 97.32% <100.00%> (+0.02%) ⬆️
Files Coverage Δ
lib/Differentiator/ReverseModeVisitor.cpp 97.32% <100.00%> (+0.02%) ⬆️

@MihailMihov MihailMihov force-pushed the reverse-mode-non-differentiable branch from 7b63512 to 7477596 Compare July 15, 2024 21:37
Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

@MihailMihov MihailMihov marked this pull request as ready for review July 15, 2024 22:02
@MihailMihov MihailMihov requested a review from vgvassilev July 15, 2024 22:02
@vgvassilev vgvassilev requested review from vaithak and parth-07 July 16, 2024 14:50
@@ -2954,6 +2977,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
"CXXMethodDecl nodes not supported yet!");
MemberExpr* clonedME = utils::BuildMemberExpr(
m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName());
auto* zero = ConstantFolder::synthesizeLiteral(m_Context.DoubleTy,
m_Context, /*val=*/0);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An IntTy would be more suitable here because we might need to zero-initialize pointers.

Copy link
Collaborator

@vaithak vaithak Jul 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also saw that this is 'DoubleTy` in forward mode. Can you also test that and fix that in a separate PR, if possible?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did change the DoubleTy to an IntTy, but that didn't actually do anything to the tests and pointers didn't work either way. I added another check for them when visiting UO_Deref, not sure if there isn't a better way to fix them however, but now the new test is passing.

SimpleFunctions1() noexcept : x(0), y(0) {}
SimpleFunctions1(double p_x, double p_y) noexcept : x(p_x), y(p_y) {}
double x;
non_differentiable double y;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also test with some pointer member types.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just added an fn_s1_field_pointer to the test, is that what you wanted me to test?

@MihailMihov MihailMihov force-pushed the reverse-mode-non-differentiable branch 2 times, most recently from f731f1d to dca4fad Compare July 17, 2024 10:04
Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

test/Gradient/NonDifferentiable.C Outdated Show resolved Hide resolved
test/Gradient/NonDifferentiableError.C Show resolved Hide resolved
// Calling the function without computing derivatives
llvm::SmallVector<Expr*, 4> ClonedArgs;
for (unsigned i = 0, e = CE->getNumArgs(); i < e; ++i)
ClonedArgs.push_back(Clone(CE->getArg(i)));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simply cloning the argument seems incorrect. What if the arguments have side-effect which can affect the derivative computation?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure that I understand the issue here. It the arguments do have side effects then those would be kept when we clone them, is that not what is expected? When do you think that this wouldn't work correctly?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider an example such as this:

some_non_differentiable_fn_call(r = u * v, s = u + v); 

Now, if we simply clone the arguments then we will not generate adjoint statements for r = u * v and s = u + v.

You don't necessarily need to fix this issue in this PR.

lib/Differentiator/ReverseModeVisitor.cpp Show resolved Hide resolved

// If we have a pointer to a member expression, which is
// non-differentiable, we just return a clone of the original expression.
if (auto* ME = dyn_cast<MemberExpr>(diff.getExpr()))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can it be handled more uniformly in VisitMemberExpr?

Copy link
Collaborator

@parth-07 parth-07 Jul 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If diff.getExpr_dx() is 0, then we would not need to add a special condition here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can it be handled more uniformly in VisitMemberExpr?

It is already handled in VisitMemberExpr where we return {clonedME, zero}, but what happened without the above check is that it tries to build something along the lines of *0 += ..., when visiting the UO_Deref.

If diff.getExpr_dx() is 0, then we would need to add a special condition here.

With the above check I eliminate one of the cases where we could end up with a 0 above, if you can think of anything else, then we should handle those too. Do you have anything in mind?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you can think of anything else, then we should handle those too.

I am concerned there might be many such cases... For example, -> operator.

It might be better to test if the diff.getExpr_dx() is a constant (or 0) instead of testing if the member has a non-differentiable attribute. This is because it will help us cover more cases. For example, the adjoint of member expressions of global class objects should also be 0 and consequently they should be handled similarly but they do not have non_differentiable attribute.

@@ -2954,6 +2984,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
"CXXMethodDecl nodes not supported yet!");
MemberExpr* clonedME = utils::BuildMemberExpr(
m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName());
auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is better to create zero inside the if-condition as it is only used there.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have changed that, also what do you think about this being IntTy or DoubleTy. In the forward mode code it was DoubleTy, but Vaibhav suggested changing it to IntTy. It didn't seem to make any difference, but maybe somewhere it will?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it will make a difference anywhere. Clang automatically adds cast nodes to convert 0 to the right type.

// of lambdas is happening in the `VisitCallExpr`. For now, only the
// declarations with lambda expressions without captures are supported.
isLambda = typeDecl && typeDecl->isLambda();
if (isLambda ||
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add a test for a local variable declaration with non_differentiable attribute?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just added fn_non_diff_var to Gradient/NonDifferentiable.C test, but I believe it's not working as expected. The correct output would be 0.00 0.00 right? I'll try to get that fixed now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please create an issue for this and resolve it in a follow-up pull-request?

@MihailMihov MihailMihov force-pushed the reverse-mode-non-differentiable branch 2 times, most recently from 84d2de1 to 635e36d Compare July 18, 2024 10:09
Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

@MihailMihov MihailMihov force-pushed the reverse-mode-non-differentiable branch 2 times, most recently from 0083191 to cd177a4 Compare July 18, 2024 18:13
Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

@vgvassilev vgvassilev force-pushed the reverse-mode-non-differentiable branch from cd177a4 to f4780e9 Compare July 20, 2024 15:58
Copy link
Collaborator

@parth-07 parth-07 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good as an initial support for [[non_differentiable]]. Thank you for working on this.

Can you please open issues for the comments which need some future work:

  • Differentiating argument expressions of non differentiable function calls. This is required because argument expressions can have side-affects.
  • Supporting non differentiate attribute on local variables.
  • Improve handling of non_differentiable variables in expressions such as * (dereference operator), -> and so on.

@vgvassilev
Copy link
Owner

Can you squash the tests into the other commit?

Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

@MihailMihov MihailMihov force-pushed the reverse-mode-non-differentiable branch from f4780e9 to bd386d6 Compare July 20, 2024 21:56
Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

@MihailMihov
Copy link
Collaborator Author

Can you please open issues for the comments which need some future work:

* Differentiating argument expressions of non differentiable function calls. This is required because argument expressions can have side-affects.

* Supporting non differentiate attribute on local variables.

* Improve handling of non_differentiable variables in expressions such as `*` (dereference operator), `->` and so on.

I opened issues for 1 and 3. For 2 this PR includes a basic test and fix, but more work may be necessary. Do I just create an issue saying that there might be a cleaner fix or something more specific?

@MihailMihov MihailMihov force-pushed the reverse-mode-non-differentiable branch from bd386d6 to f7c65a8 Compare July 20, 2024 22:14
Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

@vgvassilev
Copy link
Owner

Can you please open issues for the comments which need some future work:

* Differentiating argument expressions of non differentiable function calls. This is required because argument expressions can have side-affects.

* Supporting non differentiate attribute on local variables.

* Improve handling of non_differentiable variables in expressions such as `*` (dereference operator), `->` and so on.

I opened issues for 1 and 3. For 2 this PR includes a basic test and fix, but more work may be necessary. Do I just create an issue saying that there might be a cleaner fix or something more specific?

Yes, we should create an issue describing what this "more work may be necessary" means in technical and practical terms.

@vgvassilev vgvassilev force-pushed the reverse-mode-non-differentiable branch from f7c65a8 to b590586 Compare July 21, 2024 18:32
Copy link
Contributor

clang-tidy review says "All clean, LGTM! 👍"

@vgvassilev vgvassilev merged commit 28dea37 into vgvassilev:master Jul 21, 2024
89 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support [[clad::non_differentiable]] in the reverse-mode
4 participants